{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2325502b",
   "metadata": {},
   "source": [
    "# Customize AutoMM\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "AutoMM has a powerful yet easy-to-use configuration design.\n",
    "This tutorial walks you through the various AutoMM configurations so that you can customize AutoMM flexibly. Specifically, AutoMM configurations consist of several parts:\n",
    "\n",
    "- optim\n",
    "- env\n",
    "- model\n",
    "- data\n",
    "- distiller\n",
    "\n",
    "## Optimization\n",
    "\n",
    "### optim.lr\n",
    "\n",
    "Learning rate."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f349cfb",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.lr\": 1.0e-4})\n",
    "# set learning rate to 5.0e-4\n",
    "predictor.fit(hyperparameters={\"optim.lr\": 5.0e-4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea63ec5",
   "metadata": {},
   "source": [
    "### optim.optim_type\n",
    "\n",
    "Optimizer type.\n",
    "\n",
    "- `\"sgd\"`: stochastic gradient descent with momentum.\n",
    "- `\"adam\"`: a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments. See [this paper](https://arxiv.org/abs/1412.6980) for details.\n",
    "- `\"adamw\"`: improves Adam by decoupling weight decay from the optimization step. See [this paper](https://arxiv.org/abs/1711.05101) for details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37fcd9e2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.optim_type\": \"adamw\"})\n",
    "# use optimizer adam\n",
    "predictor.fit(hyperparameters={\"optim.optim_type\": \"adam\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeb40726",
   "metadata": {},
   "source": [
    "### optim.weight_decay\n",
    "\n",
    "Weight decay."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "852daaee",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.weight_decay\": 1.0e-3})\n",
    "# set weight decay to 1.0e-4\n",
    "predictor.fit(hyperparameters={\"optim.weight_decay\": 1.0e-4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7006101a",
   "metadata": {},
   "source": [
    "### optim.lr_decay\n",
    "\n",
    "When `optim.lr_choice` is `\"layerwise_decay\"`, later layers can have larger learning rates than earlier layers. The last/head layer\n",
    "has the largest learning rate, `optim.lr`. For a model with `n` layers, layer `i` has learning rate `optim.lr * optim.lr_decay^(n-i)`. To use one uniform learning rate, simply set the learning rate decay to `1`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adb57179",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.lr_decay\": 0.9})\n",
    "# turn off learning rate decay\n",
    "predictor.fit(hyperparameters={\"optim.lr_decay\": 1})\n",
    "```\n"
   ]
  },
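  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the formula above, the following sketch computes per-layer learning rates for a hypothetical model with `n` layers, where layer `n` is the head. The helper `layerwise_lrs` is hypothetical and not part of the AutoMM API. AutoMM's own implementation may group parameters into layers differently.\n",
    "\n",
    "```python\n",
    "def layerwise_lrs(lr, lr_decay, n):\n",
    "    # Layer i (1-indexed, layer n is the head) gets lr * lr_decay ** (n - i)\n",
    "    return [lr * lr_decay ** (n - i) for i in range(1, n + 1)]\n",
    "\n",
    "\n",
    "# With lr=1.0e-4 and lr_decay=0.9 on a 4-layer model, the head keeps 1.0e-4\n",
    "# and each earlier layer is scaled down by another factor of 0.9.\n",
    "rates = layerwise_lrs(1.0e-4, 0.9, 4)\n",
    "print(rates)\n",
    "```\n",
    "\n",
    "Setting `lr_decay` to `1` makes every entry equal to `lr`. This matches the uniform-learning-rate case described above."
   ]
  },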
  {
   "cell_type": "markdown",
   "id": "59632914",
   "metadata": {},
   "source": [
    "### optim.lr_mult\n",
    "\n",
    "When `optim.lr_choice` is `\"two_stages\"`, the last/head layer uses the largest learning rate, `optim.lr * optim.lr_mult`,\n",
    "while all other layers use the base learning rate `optim.lr`.\n",
    "To use one uniform learning rate, simply set the learning rate multiplier to `1`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1770634e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.lr_mult\": 1})\n",
    "# turn on two-stage lr with a 10x learning rate for the head layer\n",
    "predictor.fit(hyperparameters={\"optim.lr_choice\": \"two_stages\", \"optim.lr_mult\": 10})\n",
    "```\n"
   ]
  },
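  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two-stage rule can be sketched the same way. `two_stage_lrs` is a hypothetical helper, and treating the model as a flat list of layers with the head last is a simplifying assumption.\n",
    "\n",
    "```python\n",
    "def two_stage_lrs(lr, lr_mult, num_layers):\n",
    "    # Every layer except the last (head) layer uses the base learning rate;\n",
    "    # the head layer uses lr * lr_mult.\n",
    "    return [lr] * (num_layers - 1) + [lr * lr_mult]\n",
    "\n",
    "\n",
    "rates = two_stage_lrs(1.0e-4, 10, 4)\n",
    "print(rates)\n",
    "```"
   ]
  },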
  {
   "cell_type": "markdown",
   "id": "60b67198",
   "metadata": {},
   "source": [
    "### optim.lr_choice\n",
    "\n",
    "The strategy for assigning different learning rates to different layers:\n",
    "\n",
    "- `\"two_stages\"`: the head layer uses a multiplied learning rate (see the `optim.lr_mult` section for details).\n",
    "- `\"layerwise_decay\"`: learning rates decay layer by layer away from the head (see the `optim.lr_decay` section for details).\n",
    "\n",
    "To use one uniform learning rate, simply set this to `\"\"`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3dd4814",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.lr_choice\": \"layerwise_decay\"})\n",
    "# turn on two-stage lr choice\n",
    "predictor.fit(hyperparameters={\"optim.lr_choice\": \"two_stages\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31355e32",
   "metadata": {},
   "source": [
    "### optim.lr_schedule\n",
    "\n",
    "Learning rate schedule.\n",
    "\n",
    "- `\"cosine_decay\"`: the learning rate decays along a cosine curve.\n",
    "- `\"polynomial_decay\"`: the learning rate decays according to a polynomial function.\n",
    "- `\"linear_decay\"`: the learning rate decays linearly."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "100a6a22",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.lr_schedule\": \"cosine_decay\"})\n",
    "# use polynomial decay\n",
    "predictor.fit(hyperparameters={\"optim.lr_schedule\": \"polynomial_decay\"})\n",
    "```\n"
   ]
  },
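  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three schedules concrete, here is a minimal sketch of how each one maps a training step to a learning rate. It assumes the rate decays from `lr` to 0 over `total_steps` and ignores warmup. The function names and the `power` exponent of the polynomial schedule are assumptions for illustration, not AutoMM's exact implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_decay(step, total_steps, lr):\n",
    "    # Half a cosine period: lr at step 0, 0 at the last step\n",
    "    progress = step / total_steps\n",
    "    return lr * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "\n",
    "\n",
    "def linear_decay(step, total_steps, lr):\n",
    "    # Straight line from lr down to 0\n",
    "    return lr * (1.0 - step / total_steps)\n",
    "\n",
    "\n",
    "def polynomial_decay(step, total_steps, lr, power=2.0):\n",
    "    # (1 - progress) ** power; the power value here is an assumption\n",
    "    return lr * (1.0 - step / total_steps) ** power\n",
    "\n",
    "\n",
    "for step in (0, 50, 100):\n",
    "    print(step, cosine_decay(step, 100, 1e-4), linear_decay(step, 100, 1e-4))\n",
    "```"
   ]
  },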
  {
   "cell_type": "markdown",
   "id": "5ce9a4af",
   "metadata": {},
   "source": [
    "### optim.max_epochs\n",
    "\n",
    "Stop training once this number of epochs is reached."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a8c032",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.max_epochs\": 10})\n",
    "# train 20 epochs\n",
    "predictor.fit(hyperparameters={\"optim.max_epochs\": 20})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53264b7c",
   "metadata": {},
   "source": [
    "### optim.max_steps\n",
    "\n",
    "Stop training after this number of steps. Training stops as soon as either `optim.max_steps` or `optim.max_epochs` is reached, whichever comes first.\n",
    "By default, `optim.max_steps` is disabled by setting it to -1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2482fc3e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.max_steps\": -1})\n",
    "# train 100 steps\n",
    "predictor.fit(hyperparameters={\"optim.max_steps\": 100})\n",
    "```\n"
   ]
  },
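  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interaction between the two limits amounts to taking a minimum. Below is a minimal sketch. It assumes `steps_per_epoch` counts optimizer steps and ignores early stopping, which can end training even sooner. `stopping_step` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def stopping_step(max_epochs, steps_per_epoch, max_steps=-1):\n",
    "    # Step limit implied by optim.max_epochs\n",
    "    epoch_limit = max_epochs * steps_per_epoch\n",
    "    # -1 disables optim.max_steps, so only the epoch limit applies\n",
    "    if max_steps == -1:\n",
    "        return epoch_limit\n",
    "    # Otherwise training ends at whichever limit is reached first\n",
    "    return min(max_steps, epoch_limit)\n",
    "\n",
    "\n",
    "print(stopping_step(10, 50))       # 500\n",
    "print(stopping_step(10, 50, 100))  # 100\n",
    "```"
   ]
  },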
  {
   "cell_type": "markdown",
   "id": "cd6fa991",
   "metadata": {},
   "source": [
    "### optim.warmup_steps\n",
    "\n",
    "Warm up the learning rate from 0 to `optim.lr` over this fraction of the total training steps at the beginning of training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34d3a967",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.warmup_steps\": 0.1})\n",
    "# do learning rate warmup in the first 20% steps.\n",
    "predictor.fit(hyperparameters={\"optim.warmup_steps\": 0.2})\n",
    "```\n"
   ]
  },
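  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `optim.warmup_steps` is a fraction, the actual number of warmup steps depends on the total number of training steps. The sketch below assumes a linear ramp, which is a common choice but an assumption here. `warmup_lr` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def warmup_lr(step, total_steps, lr, warmup_steps=0.1):\n",
    "    # Convert the fraction into a number of steps, e.g. 10% of 1000 steps -> 100\n",
    "    num_warmup = int(total_steps * warmup_steps)\n",
    "    if step < num_warmup:\n",
    "        # Assumed linear ramp from 0 up to lr\n",
    "        return lr * step / num_warmup\n",
    "    # After warmup, the configured optim.lr_schedule takes over (not modeled here)\n",
    "    return lr\n",
    "\n",
    "\n",
    "for step in (0, 50, 100):\n",
    "    print(step, warmup_lr(step, 1000, 1e-4))\n",
    "```"
   ]
  },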
  {
   "cell_type": "markdown",
   "id": "80452db2",
   "metadata": {},
   "source": [
    "### optim.patience\n",
    "\n",
    "Stop training after this number of checks with no improvement. The check frequency is controlled by `optim.val_check_interval`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3bcd482",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.patience\": 10})\n",
    "# set patience to 5 checks\n",
    "predictor.fit(hyperparameters={\"optim.patience\": 5})\n",
    "```\n"
   ]
  },
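  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Patience-based early stopping can be sketched as a counter of consecutive checks without improvement. The sketch assumes a higher validation score is better and ignores any minimum-improvement threshold. `checks_until_stop` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def checks_until_stop(val_scores, patience):\n",
    "    # Returns the validation check at which training stops,\n",
    "    # or len(val_scores) if early stopping never triggers.\n",
    "    best = float(\"-inf\")\n",
    "    bad_checks = 0\n",
    "    for i, score in enumerate(val_scores, start=1):\n",
    "        if score > best:\n",
    "            best = score\n",
    "            bad_checks = 0  # improvement resets the counter\n",
    "        else:\n",
    "            bad_checks += 1\n",
    "            if bad_checks >= patience:\n",
    "                return i\n",
    "    return len(val_scores)\n",
    "\n",
    "\n",
    "print(checks_until_stop([0.5, 0.6, 0.6, 0.59, 0.58], patience=3))  # 5\n",
    "```"
   ]
  },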
  {
   "cell_type": "markdown",
   "id": "2765c653",
   "metadata": {},
   "source": [
    "### optim.val_check_interval\n",
    "\n",
    "How often to check the validation set within one training epoch. It can be a float or an int.\n",
    "\n",
    "- Pass a float in the range [0.0, 1.0] to check after every such fraction of a training epoch.\n",
    "- Pass an int to check after a fixed number of training batches."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ee8c226",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.val_check_interval\": 0.5})\n",
    "# check validation set 4 times during a training epoch\n",
    "predictor.fit(hyperparameters={\"optim.val_check_interval\": 0.25})\n",
    "```\n"
   ]
  },
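  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The float and int forms map to a number of training batches between validation checks roughly as follows. This is a simplified sketch with a hypothetical helper name. The trainer's exact rounding may differ.\n",
    "\n",
    "```python\n",
    "def batches_between_checks(val_check_interval, batches_per_epoch):\n",
    "    if isinstance(val_check_interval, float):\n",
    "        # Fraction of an epoch, e.g. 0.25 of 200 batches -> every 50 batches\n",
    "        return max(1, int(batches_per_epoch * val_check_interval))\n",
    "    # An int is already a batch count\n",
    "    return val_check_interval\n",
    "\n",
    "\n",
    "print(batches_between_checks(0.25, 200))  # 50\n",
    "print(batches_between_checks(100, 200))   # 100\n",
    "```"
   ]
  },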
  {
   "cell_type": "markdown",
   "id": "e28553d8",
   "metadata": {},
   "source": [
    "### optim.gradient_clip_algorithm\n",
    "\n",
    "The gradient clipping algorithm to use. Supports clipping gradients by value (`\"value\"`) or by norm (`\"norm\"`)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7526131f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.gradient_clip_algorithm\": \"norm\"})\n",
    "# clip gradients by value\n",
    "predictor.fit(hyperparameters={\"optim.gradient_clip_algorithm\": \"value\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c29d0454",
   "metadata": {},
   "source": [
    "### optim.gradient_clip_val\n",
    "\n",
    "Gradient clipping value, which can be the absolute value or gradient norm depending on the choice of `optim.gradient_clip_algorithm`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50e90350",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.gradient_clip_val\": 1})\n",
    "# cap the gradients to 5\n",
    "predictor.fit(hyperparameters={\"optim.gradient_clip_val\": 5})\n",
    "```\n"
   ]
  },
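  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two clipping algorithms is easiest to see on a flat list of gradient values. The sketch below treats `grads` as a stand-in for all model gradients. It roughly mirrors the behavior of PyTorch's `torch.nn.utils.clip_grad_value_` and `torch.nn.utils.clip_grad_norm_` (2-norm). `clip_gradients` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def clip_gradients(grads, clip_val, algorithm=\"norm\"):\n",
    "    if algorithm == \"value\":\n",
    "        # Clamp each element into [-clip_val, clip_val]\n",
    "        return [max(-clip_val, min(clip_val, g)) for g in grads]\n",
    "    # \"norm\": rescale all gradients so their global 2-norm is at most clip_val\n",
    "    total_norm = math.sqrt(sum(g * g for g in grads))\n",
    "    if total_norm <= clip_val:\n",
    "        return list(grads)\n",
    "    scale = clip_val / total_norm\n",
    "    return [g * scale for g in grads]\n",
    "\n",
    "\n",
    "print(clip_gradients([3.0, -4.0], 1.0, \"value\"))  # [1.0, -1.0]\n",
    "print(clip_gradients([3.0, 4.0], 1.0))  # direction kept, norm scaled from 5 to 1\n",
    "```"
   ]
  },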
  {
   "cell_type": "markdown",
   "id": "02d07866",
   "metadata": {},
   "source": [
    "### optim.track_grad_norm\n",
    "\n",
    "Track the p-norm of gradients during training. Set it to `\"inf\"` to track the infinity norm, or to `-1` to disable tracking. If using Automatic Mixed Precision (AMP), the gradients are unscaled before being logged."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b60c371",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM (no tracking)\n",
    "predictor.fit(hyperparameters={\"optim.track_grad_norm\": -1})\n",
    "# track the 2-norm\n",
    "predictor.fit(hyperparameters={\"optim.track_grad_norm\": 2})\n",
    "```\n"
   ]
  },
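  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the tracked quantity is the ordinary p-norm taken over the gradient values. Below is a sketch on a flat list of values, using the hypothetical helper `grad_norm`.\n",
    "\n",
    "```python\n",
    "def grad_norm(grads, p):\n",
    "    # Infinity norm: the largest absolute gradient value\n",
    "    if p == float(\"inf\"):\n",
    "        return max(abs(g) for g in grads)\n",
    "    # General p-norm: (sum of |g| ** p) ** (1 / p)\n",
    "    return sum(abs(g) ** p for g in grads) ** (1.0 / p)\n",
    "\n",
    "\n",
    "print(grad_norm([3.0, -4.0], 2))             # 5.0\n",
    "print(grad_norm([3.0, -4.0], float(\"inf\")))  # 4.0\n",
    "```"
   ]
  },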
  {
   "cell_type": "markdown",
   "id": "abe87d32",
   "metadata": {},
   "source": [
    "### optim.log_every_n_steps\n",
    "\n",
    "How often to log, measured in training steps."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5fe49c",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.log_every_n_steps\": 10})\n",
    "# log once every 50 steps\n",
    "predictor.fit(hyperparameters={\"optim.log_every_n_steps\": 50})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f30c7e",
   "metadata": {},
   "source": [
    "### optim.top_k\n",
    "\n",
    "Based on the validation score, select the top k model checkpoints for model averaging."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17258cf4",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.top_k\": 3})\n",
    "# use top 5 checkpoints\n",
    "predictor.fit(hyperparameters={\"optim.top_k\": 5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9f1cfc",
   "metadata": {},
   "source": [
    "### optim.top_k_average_method\n",
    "\n",
    "The strategy used to average the top k model checkpoints.\n",
    "\n",
    "- `\"greedy_soup\"`: tries to add the checkpoints, from best to worst, into the averaging pool and stops once the performance of the averaged checkpoint decreases. See [the paper](https://arxiv.org/pdf/2203.05482.pdf) for details.\n",
    "- `\"uniform_soup\"`: averages all the top k checkpoints as the final checkpoint.\n",
    "- `\"best\"`: picks the checkpoint with the best validation performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8236ab40",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.top_k_average_method\": \"greedy_soup\"})\n",
    "# average all the top k checkpoints\n",
    "predictor.fit(hyperparameters={\"optim.top_k_average_method\": \"uniform_soup\"})\n",
    "```\n"
   ]
  },
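  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging strategies can be sketched on toy checkpoints represented as dicts of float parameters. The sketch follows the description above: greedy soup stops at the first checkpoint that lowers the averaged validation score. `average`, `greedy_soup`, and `evaluate` are hypothetical stand-ins, not AutoMM functions.\n",
    "\n",
    "```python\n",
    "def average(weight_dicts):\n",
    "    # Uniform average of every parameter across the given checkpoints\n",
    "    return {k: sum(w[k] for w in weight_dicts) / len(weight_dicts) for k in weight_dicts[0]}\n",
    "\n",
    "\n",
    "def greedy_soup(checkpoints, scores, evaluate):\n",
    "    # Visit checkpoints from best to worst validation score\n",
    "    order = sorted(range(len(checkpoints)), key=lambda i: scores[i], reverse=True)\n",
    "    pool = [checkpoints[order[0]]]\n",
    "    best_score = scores[order[0]]\n",
    "    for i in order[1:]:\n",
    "        candidate = pool + [checkpoints[i]]\n",
    "        score = evaluate(average(candidate))\n",
    "        if score < best_score:\n",
    "            break  # stop once averaging makes validation performance worse\n",
    "        pool, best_score = candidate, score\n",
    "    return average(pool)\n",
    "\n",
    "\n",
    "def evaluate(weights):\n",
    "    # Toy validation score: higher when the single parameter is close to 1.0\n",
    "    return -abs(weights[\"w\"] - 1.0)\n",
    "\n",
    "\n",
    "ckpts = [{\"w\": 1.2}, {\"w\": 0.8}, {\"w\": 5.0}]\n",
    "# uniform_soup would be average(ckpts); best would be the top-scoring checkpoint\n",
    "print(greedy_soup(ckpts, [evaluate(c) for c in ckpts], evaluate))\n",
    "```"
   ]
  },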
  {
   "cell_type": "markdown",
   "id": "652f15b4",
   "metadata": {},
   "source": [
    "### optim.peft\n",
    "\n",
    "Options for parameter-efficient finetuning, which finetunes only a small portion of the parameters instead of the whole pretrained backbone.\n",
    "\n",
    "- `\"bit_fit\"`: bias parameters only. See [this paper](https://arxiv.org/pdf/2106.10199.pdf) for details.\n",
    "- `\"norm_fit\"`: normalization parameters + bias parameters. See [this paper](https://arxiv.org/pdf/2003.00152.pdf) for details.\n",
    "- `\"lora\"`: LoRA Adaptors. See [this paper](https://arxiv.org/pdf/2106.09685.pdf) for details.\n",
    "- `\"lora_bias\"`: LoRA Adaptors + bias parameters.\n",
    "- `\"lora_norm\"`: LoRA Adaptors + normalization parameters + bias parameters.\n",
    "- `\"ia3\"`: IA3 algorithm. See [this paper](https://arxiv.org/abs/2205.05638) for details.\n",
    "- `\"ia3_bias\"`: IA3 + bias parameters.\n",
    "- `\"ia3_norm\"`: IA3 + normalization parameters + bias parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a072add",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.peft\": None})\n",
    "# finetune only bias parameters\n",
    "predictor.fit(hyperparameters={\"optim.peft\": \"bit_fit\"})\n",
    "# finetune with IA3 + BitFit\n",
    "predictor.fit(hyperparameters={\"optim.peft\": \"ia3_bias\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a138ce84-1462-4d67-a82e-80e829a7a57d",
   "metadata": {},
   "source": [
    "### optim.skip_final_val\n",
    "\n",
    "Whether to skip the final validation after training is signaled to stop."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6830fa7f-d6ef-4578-9efd-16923fca0918",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optim.skip_final_val\": False})\n",
    "# skip the final validation\n",
    "predictor.fit(hyperparameters={\"optim.skip_final_val\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3cfbd99",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
    "### env.num_gpus\n",
    "\n",
    "The number of GPUs to use. If set to -1, AutoMM uses all available GPUs, counted by `env.num_gpus = torch.cuda.device_count()`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6908008",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, all available gpus are used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_gpus\": -1})\n",
    "# use 1 gpu only\n",
    "predictor.fit(hyperparameters={\"env.num_gpus\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e3c075",
   "metadata": {},
   "source": [
    "### env.per_gpu_batch_size\n",
    "\n",
    "The batch size for each GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "692d5f4f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 8})\n",
    "# use batch size 16 per GPU\n",
    "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 16})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a23b8cc",
   "metadata": {},
   "source": [
    "### env.batch_size\n",
    "\n",
    "The batch size to use in each step of training. If `env.batch_size` is larger than `env.per_gpu_batch_size * env.num_gpus`, we accumulate gradients to reach the effective `env.batch_size` before performing one optimization step. The accumulation steps are calculated by `env.batch_size // (env.per_gpu_batch_size * env.num_gpus)`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14b5a0c2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.batch_size\": 128})\n",
    "# use batch size 256\n",
    "predictor.fit(hyperparameters={\"env.batch_size\": 256})\n",
    "```\n"
   ]
  },
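  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plugging the defaults into the formula above gives the number of gradient accumulation steps. The sketch below uses a hypothetical helper. Clamping the result to at least 1 is an assumption; it covers the case where one step across all GPUs already exceeds `env.batch_size`.\n",
    "\n",
    "```python\n",
    "def accumulation_steps(batch_size, per_gpu_batch_size, num_gpus):\n",
    "    # Forward/backward passes accumulated before one optimizer step\n",
    "    return max(1, batch_size // (per_gpu_batch_size * num_gpus))\n",
    "\n",
    "\n",
    "print(accumulation_steps(128, 8, 1))  # 16\n",
    "print(accumulation_steps(128, 8, 4))  # 4\n",
    "```"
   ]
  },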
  {
   "cell_type": "markdown",
   "id": "fdf820c7",
   "metadata": {},
   "source": [
    "### env.inference_batch_size_ratio\n",
    "\n",
    "Prediction or evaluation uses a larger per-GPU batch size, `env.per_gpu_batch_size * env.inference_batch_size_ratio`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d8a8ca",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.inference_batch_size_ratio\": 4})\n",
    "# use 2x per gpu batch size during prediction or evaluation\n",
    "predictor.fit(hyperparameters={\"env.inference_batch_size_ratio\": 2})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d492098",
   "metadata": {},
   "source": [
    "### env.precision\n",
    "\n",
    "Supports double (`64`, `\"64\"`, `\"64-true\"`), float (`32`, `\"32\"`, `\"32-true\"`), bfloat16 (`\"bf16-mixed\"`, `\"bf16-true\"`), or float16 (`\"16-mixed\"`, `\"16-true\"`) precision training. For more details, refer to the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision).\n",
    "\n",
    "Mixed precision like `\"16-mixed\"` combines 32-bit and 16-bit floating-point types to reduce the memory footprint during model training. This can improve performance, achieving speedups of 3x or more on modern GPUs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c348024",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.precision\": \"16-mixed\"})\n",
    "# use bfloat16 mixed precision\n",
    "predictor.fit(hyperparameters={\"env.precision\": \"bf16-mixed\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa8d934",
   "metadata": {},
   "source": [
    "### env.num_workers\n",
    "\n",
    "The number of worker processes used by the PyTorch dataloader in training. Note that more workers do not always bring a speedup, especially when `env.strategy = \"ddp_spawn\"`.\n",
    "For more details, see the guideline [here](https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html#distributed-data-parallel)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "789bed40",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_workers\": 2})\n",
    "# use 4 workers in the training dataloader\n",
    "predictor.fit(hyperparameters={\"env.num_workers\": 4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86faccf9",
   "metadata": {},
   "source": [
    "### env.num_workers_inference\n",
    "\n",
    "The number of worker processes used by the PyTorch dataloader in prediction or evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4040737",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_workers_inference\": 2})\n",
    "# use 4 workers in the prediction/evaluation dataloader\n",
    "predictor.fit(hyperparameters={\"env.num_workers_inference\": 4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4ea42b0",
   "metadata": {},
   "source": [
    "### env.strategy\n",
    "\n",
    "Distributed training mode.\n",
    "\n",
    "- `\"dp\"`: data parallel.\n",
    "- `\"ddp\"`: distributed data parallel (python script based).\n",
    "- `\"ddp_spawn\"`: distributed data parallel (spawn based).\n",
    "\n",
    "See [here](https://lightning.ai/docs/pytorch/stable/extensions/strategy.html#selecting-a-built-in-strategy) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1c3e6f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.strategy\": \"ddp_spawn\"})\n",
    "# use ddp during training\n",
    "predictor.fit(hyperparameters={\"env.strategy\": \"ddp\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### env.accelerator\n",
    "\n",
    "Supports `\"cpu\"`, `\"gpu\"`, or `\"auto\"` (default).\n",
    "In auto mode, the GPU takes priority if both a CPU and a GPU are available.\n",
    "\n",
    "See [here](https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.accelerator\": \"auto\"})\n",
    "# use cpu for training\n",
    "predictor.fit(hyperparameters={\"env.accelerator\": \"cpu\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### env.compile.turn_on\n",
    "\n",
    "Whether to compile PyTorch models through [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html) (default: False).\n",
    "Note that compiling a model takes extra time, so it is recommended for large models and long training runs."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.compile.turn_on\": False})\n",
    "# turn on torch.compile\n",
    "predictor.fit(hyperparameters={\"env.compile.turn_on\": True})\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### env.compile.mode\n",
    "\n",
    "Can be `\"default\"`, `\"reduce-overhead\"`, `\"max-autotune\"`, or `\"max-autotune-no-cudagraphs\"`.\n",
    "For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.compile.mode\": \"default\"})\n",
    "# reduce Python overhead with CUDA graphs, useful for small batches\n",
    "predictor.fit(hyperparameters={\"env.compile.mode\": \"reduce-overhead\"})\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### env.compile.dynamic\n",
    "\n",
    "Whether to use dynamic shape tracing (default: True). For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.compile.dynamic\": True})\n",
    "# assumes a static input shape across mini-batches.\n",
    "predictor.fit(hyperparameters={\"env.compile.dynamic\": False})\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### env.compile.backend\n",
    "\n",
    "Backend to be used when compiling the model. For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.compile.backend\": \"inductor\"})\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "id": "f2b4a051",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "### model.names\n",
    "\n",
    "Choose what types of models to use.\n",
    "\n",
    "- `\"hf_text\"`: the pretrained text models from [Huggingface](https://huggingface.co/).\n",
    "- `\"timm_image\"`: the pretrained image models from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models).\n",
    "- `\"clip\"`: the pretrained CLIP models.\n",
    "- `\"categorical_mlp\"`: MLP for categorical data.\n",
    "- `\"numerical_mlp\"`: MLP for numerical data.\n",
    "- `\"ft_transformer\"`: [FT-Transformer](https://arxiv.org/pdf/2106.11959.pdf) for tabular (categorical and numerical) data.\n",
    "- `\"fusion_mlp\"`: MLP-based fusion for features from multiple backbones.\n",
    "- `\"fusion_transformer\"`: transformer-based fusion for features from multiple backbones.\n",
    "- `\"sam\"`: the pretrained Segment Anything Model from [Huggingface](https://huggingface.co/).\n",
    "\n",
    "If no data of one modality is detected, the related model types will be automatically removed in training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bc0c2e5",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\", \"timm_image\", \"clip\", \"categorical_mlp\", \"numerical_mlp\", \"fusion_mlp\"]})\n",
    "# use only text models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\"]})\n",
    "# use only image models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"timm_image\"]})\n",
    "# use only clip models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"clip\"]})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d1c833",
   "metadata": {},
   "source": [
    "### model.hf_text.checkpoint_name\n",
    "\n",
    "Specify a text backbone supported by the Huggingface [AutoModel](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodel)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27360756",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"google/electra-base-discriminator\"})\n",
    "# choose roberta base\n",
    "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"roberta-base\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bff5c1bd",
   "metadata": {},
   "source": [
    "### model.hf_text.pooling_mode\n",
    "\n",
    "The feature pooling mode for transformer architectures.\n",
    "\n",
    "- `cls`: uses the cls feature vector to represent a sentence.\n",
    "- `mean`: averages all the token feature vectors to represent a sentence."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1359b199",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"cls\"})\n",
    "# use mean pooling\n",
    "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"mean\"})\n",
    "```\n"
   ]
  },
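  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two pooling modes can be sketched on plain lists of token vectors. The sketch assumes the CLS token sits at position 0 and that mean pooling skips padding tokens via the attention mask. `pool_features` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def pool_features(token_features, attention_mask, mode=\"cls\"):\n",
    "    # token_features: one vector (list of floats) per token; position 0 is CLS\n",
    "    # attention_mask: 1 for real tokens, 0 for padding\n",
    "    if mode == \"cls\":\n",
    "        return token_features[0]\n",
    "    # mean: average the vectors of the non-padding tokens, dimension by dimension\n",
    "    kept = [vec for vec, m in zip(token_features, attention_mask) if m == 1]\n",
    "    return [sum(values) / len(kept) for values in zip(*kept)]\n",
    "\n",
    "\n",
    "features = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last token is padding\n",
    "mask = [1, 1, 0]\n",
    "print(pool_features(features, mask, \"cls\"))   # [1.0, 2.0]\n",
    "print(pool_features(features, mask, \"mean\"))  # [2.0, 3.0]\n",
    "```"
   ]
  },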
  {
   "cell_type": "markdown",
   "id": "87a0ac4c",
   "metadata": {},
   "source": [
    "### model.hf_text.tokenizer_name\n",
    "\n",
    "Choose the text tokenizer. It is recommended to use the default auto tokenizer.\n",
    "\n",
    "- `hf_auto`: the [Huggingface auto tokenizer](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer).\n",
    "- `bert`: the [BERT tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/bert#transformers.BertTokenizer).\n",
    "- `electra`: the [ELECTRA tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/electra#transformers.ElectraTokenizer).\n",
    "- `clip`: the [CLIP tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/clip#transformers.CLIPTokenizer)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c50a4d84",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"hf_auto\"})\n",
    "# using the tokenizer of the ELECTRA model\n",
    "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"electra\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8372c024",
   "metadata": {},
   "source": [
    "### model.hf_text.max_text_len\n",
    "\n",
    "Set the maximum text length. Different models may allow different maximum lengths. If `model.hf_text.max_text_len` > 0, we choose the minimum between `model.hf_text.max_text_len` and the maximum length allowed by the model. Setting `model.hf_text.max_text_len` <= 0 would use the model's maximum length."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1db84e23",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": 512})\n",
    "# use the maximum length allowed by the model\n",
    "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": -1})\n",
    "```\n"
   ]
  },
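  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The length rule above is a small piece of arithmetic. Below is a sketch with a hypothetical helper, where `model_max_len` stands for the maximum length the chosen model allows.\n",
    "\n",
    "```python\n",
    "def effective_max_len(max_text_len, model_max_len):\n",
    "    # Positive values are capped by the model's limit\n",
    "    if max_text_len > 0:\n",
    "        return min(max_text_len, model_max_len)\n",
    "    # Values <= 0 fall back to the model's own maximum length\n",
    "    return model_max_len\n",
    "\n",
    "\n",
    "print(effective_max_len(1024, 512))  # 512\n",
    "print(effective_max_len(128, 512))   # 128\n",
    "print(effective_max_len(-1, 512))    # 512\n",
    "```"
   ]
  },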
  {
   "cell_type": "markdown",
   "id": "67a57f56",
   "metadata": {},
   "source": [
    "### model.hf_text.insert_sep\n",
    "\n",
    "Whether to insert the SEP token between texts from different columns of a dataframe."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c8f6b9",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": True})\n",
    "# use no SEP token.\n",
    "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "525da692",
   "metadata": {},
   "source": [
    "### model.hf_text.text_segment_num\n",
    "\n",
    "How many text segments are used in a token sequence. Each text segment has one [token type ID](https://huggingface.co/transformers/v2.11.0/glossary.html#token-type-ids). We choose the minimum between `model.hf_text.text_segment_num` and the default used by the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "519d778a",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 2})\n",
    "# use 1 text segment\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffe782fd",
   "metadata": {},
   "source": [
    "### model.hf_text.stochastic_chunk\n",
    "\n",
    "Whether to randomly cut a text chunk if a sample's number of text tokens is larger than `model.hf_text.max_text_len`. If False, cut a token sequence from index 0 to the maximum allowed length. Otherwise, randomly sample a start index to cut a text chunk."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71f3f97e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": False})\n",
    "# select a stochastic text chunk if a text sequence is over-long\n",
    "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": True})\n",
    "```\n"
   ]
  },
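  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chunking an over-long token sequence can be sketched as slicing a window of `max_len` tokens. This simplified version ignores special tokens such as CLS/SEP, which a real tokenizer would reserve room for. `chunk_tokens` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def chunk_tokens(token_ids, max_len, stochastic=False, rng=random):\n",
    "    # Sequences that already fit are returned unchanged\n",
    "    if len(token_ids) <= max_len:\n",
    "        return token_ids\n",
    "    # Deterministic: start at index 0; stochastic: a random start that still fits\n",
    "    start = rng.randint(0, len(token_ids) - max_len) if stochastic else 0\n",
    "    return token_ids[start:start + max_len]\n",
    "\n",
    "\n",
    "tokens = list(range(10))\n",
    "print(chunk_tokens(tokens, 4))  # [0, 1, 2, 3]\n",
    "print(chunk_tokens(tokens, 4, stochastic=True, rng=random.Random(0)))\n",
    "```"
   ]
  },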
  {
   "cell_type": "markdown",
   "id": "3d63c524",
   "metadata": {},
   "source": [
    "### model.hf_text.text_aug_detect_length\n",
    "\n",
    "Perform text augmentation only when the number of text tokens is at least `model.hf_text.text_aug_detect_length`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2538fbd3",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 10})\n",
    "# allow text augmentation for texts with at least 5 tokens\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4755a0c",
   "metadata": {},
   "source": [
    "### model.hf_text.text_trivial_aug_maxscale\n",
    "\n",
    "The maximum proportion of text tokens to augment. For each text token sequence, we randomly sample a proportion in [0, `model.hf_text.text_trivial_aug_maxscale`] and one of four trivial augmentations (synonym replacement, random word swap, random word deletion, or random punctuation insertion) to apply."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b44d07f",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, AutoMM doesn't do text augmentation\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0})\n",
    "# Enable trivial augmentation by setting the max scale to 0.1\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0.1})\n",
    "```\n"
   ]
  },
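  {
   "cell_type": "markdown",
   "id": "sketch-trivial-aug",
   "metadata": {},
   "source": [
    "The sampling procedure described above can be sketched as follows. This is a simplified illustration under stated assumptions and not AutoMM's implementation. It operates on whitespace-separated words instead of tokenizer tokens, and it always modifies at least one word. It also implements only two of the four operations (random word swap and random word deletion), because synonym replacement needs a thesaurus such as WordNet. All function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def random_swap(words, n, rng):\n",
    "    words = list(words)\n",
    "    for _ in range(n):\n",
    "        i, j = rng.randrange(len(words)), rng.randrange(len(words))\n",
    "        words[i], words[j] = words[j], words[i]\n",
    "    return words\n",
    "\n",
    "\n",
    "def random_delete(words, n, rng):\n",
    "    words = list(words)\n",
    "    for _ in range(min(n, len(words) - 1)):  # always keep at least one word\n",
    "        words.pop(rng.randrange(len(words)))\n",
    "    return words\n",
    "\n",
    "\n",
    "def trivial_augment(text, max_scale, rng=random):\n",
    "    words = text.split()\n",
    "    if max_scale <= 0 or not words:\n",
    "        return text  # a max scale of 0 disables augmentation\n",
    "    scale = rng.uniform(0, max_scale)  # proportion of words to modify\n",
    "    n = max(1, round(scale * len(words)))\n",
    "    op = rng.choice([random_swap, random_delete])  # one operation per sequence\n",
    "    return ' '.join(op(words, n, rng))\n",
    "\n",
    "\n",
    "print(trivial_augment('the quick brown fox jumps over the lazy dog', 0.1))\n",
    "```\n"
   ]
  },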
  {
   "cell_type": "markdown",
   "id": "5019b9f4",
   "metadata": {},
   "source": [
    "### model.hf_text.gradient_checkpointing\n",
    "\n",
    "Whether to enable gradient checkpointing, which reduces the memory needed to compute gradients at the cost of extra computation. For more details, see [this repository on gradient checkpointing](https://github.com/cybertronai/gradient-checkpointing)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38476035",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, AutoMM doesn't turn on gradient checkpointing\n",
    "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": False})\n",
    "# Turn on gradient checkpointing\n",
    "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": True})\n",
    "```\n"
   ]
  },
  {
    "cell_type": "markdown",
    "id": "9b29a7b",
    "metadata": {},
    "source": [
     "### model.ft_transformer.checkpoint_name\n",
     "\n",
     "Initialize the ft_transformer backbone from pre-trained weights, given either as a local checkpoint path or as a URL."
    ]
   },
   {
    "cell_type": "markdown",
    "id": "6s39392",
    "metadata": {},
    "source": [
     "```\n",
     "# by default, AutoMM doesn't use pre-trained weights\n",
     "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": None})\n",
     "# initialize the ft_transformer backbone from a local checkpoint\n",
     "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'my_checkpoint.ckpt'})\n",
     "# initialize the ft_transformer backbone from a checkpoint URL\n",
     "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt'})\n",
     "```"
    ]
   },
  {
   "cell_type": "markdown",
   "id": "9b3d3a7b",
   "metadata": {},
   "source": [
    "### model.ft_transformer.num_blocks\n",
    "\n",
    "Number of transformer blocks in the ft_transformer backbone."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "642d7392",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.num_blocks\": 3})\n",
    "# increase the number of blocks to 5 in ft_transformer\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.num_blocks\": 5})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5340d090",
   "metadata": {},
   "source": [
    "### model.ft_transformer.token_dim\n",
    "\n",
    "The dimension of the tokens produced by the categorical and numerical tokenizers in ft_transformer."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "780bddb0",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.token_dim\": 192})\n",
    "# increase the token dimension to 256 in ft_transformer\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.token_dim\": 256})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87348422",
   "metadata": {},
   "source": [
    "### model.ft_transformer.hidden_size\n",
    "\n",
    "The model embedding dimension of the ft_transformer backbone."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "996c3a0e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.hidden_size\": 192})\n",
    "# increase the model embedding dimension to 256 in ft_transformer\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.hidden_size\": 256})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6523568",
   "metadata": {},
   "source": [
    "### model.ft_transformer.ffn_hidden_size\n",
    "\n",
    "The hidden layer dimension of the FFN (Feed-Forward) layer in [ft_transformer blocks](https://arxiv.org/pdf/2106.11959v5.pdf). The [Transformer](https://arxiv.org/pdf/1706.03762.pdf) paper sets the FFN hidden layer dimension to $4\\times$ the model hidden size. AutoMM instead sets it equal to the model hidden size by default."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e448822",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.ffn_hidden_size\": 192})\n",
    "# increase the FFN hidden layer dimension to 256 in ft_transformer\n",
    "predictor.fit(hyperparameters={\"model.ft_transformer.ffn_hidden_size\": 256})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec6d7d31",
   "metadata": {},
   "source": [
    "### model.timm_image.checkpoint_name\n",
    "\n",
    "Select an image backbone from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20aa69bb",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"swin_base_patch4_window7_224\"})\n",
    "# choose a ViT-Base backbone\n",
    "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"vit_base_patch32_224\"})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### model.timm_image.train_transforms\n",
    "\n",
    "Augment images in training. This accepts either a list of strings chosen from (`resize_to_square`, `resize_shorter_side`, `center_crop`, `random_resize_crop`, `random_horizontal_flip`, `random_vertical_flip`, `color_jitter`, `affine`, `randaug`, `trivial_augment`) or a list of callable and pickle-able transform objects, e.g., [torchvision transforms](https://pytorch.org/vision/stable/transforms.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"resize_shorter_side\", \"center_crop\", \"trivial_augment\"]})\n",
    "# use random resize crop and random horizontal flip\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"random_resize_crop\", \"random_horizontal_flip\"]})\n",
    "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n",
    "import torchvision\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [torchvision.transforms.RandomResizedCrop(224), torchvision.transforms.RandomHorizontalFlip()]})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### model.timm_image.val_transforms\n",
    "\n",
    "Transform images in validation/test/deployment. Like `model.timm_image.train_transforms`, this accepts a list of strings or a list of callable and pickle-able objects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_shorter_side\", \"center_crop\"]})\n",
    "# resize image to square\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_to_square\"]})\n",
    "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n",
    "import torchvision\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [torchvision.transforms.Resize((224, 224))]})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db39fec1",
   "metadata": {},
   "source": [
    "### model.mmdet_image.checkpoint_name\n",
    "\n",
    "Specify a model supported by [MMDetection](https://mmdetection.readthedocs.io/en/latest/user_guides/inference.html). Use \"yolox_nano\", \"yolox_tiny\", \"yolox_s\", \"yolox_m\", \"yolox_l\", or \"yolox_x\" to run our modified YOLOX models, which are compatible with AutoGluon."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf07cc08",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"yolov3_mobilenetv2_8xb24-320-300e_coco\"})\n",
    "# choose YOLOX-L\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"yolox_l\"})\n",
    "# choose DINO-SwinL\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"dino-5scale_swin-l_8xb2-36e_coco\"})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64f465e9",
   "metadata": {},
   "source": [
    "### model.mmdet_image.output_bbox_format\n",
    "\n",
    "The output bounding box format:\n",
    "\n",
    "- `\"xyxy\"`: Output [x1, y1, x2, y2]. A box is represented by two corners: (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner. This is the default output format.\n",
    "- `\"xywh\"`: Output [x1, y1, w, h]. A box is represented by its top-left corner (x1, y1) together with its width w and height h."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87be5d56",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.output_bbox_format\": \"xyxy\"})\n",
    "# choose xywh output format\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.output_bbox_format\": \"xywh\"})\n",
    "```"
   ]
  },
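  {
   "cell_type": "markdown",
   "id": "sketch-bbox-format",
   "metadata": {},
   "source": [
    "Converting between the two formats is simple arithmetic. The sketch below uses hypothetical helper names and assumes continuous coordinates, so the width is `x2 - x1`. Some codebases add 1 to the width and height when they use inclusive pixel indices.\n",
    "\n",
    "```python\n",
    "def xyxy_to_xywh(box):\n",
    "    x1, y1, x2, y2 = box\n",
    "    return [x1, y1, x2 - x1, y2 - y1]\n",
    "\n",
    "\n",
    "def xywh_to_xyxy(box):\n",
    "    x1, y1, w, h = box\n",
    "    return [x1, y1, x1 + w, y1 + h]\n",
    "\n",
    "\n",
    "print(xyxy_to_xywh([10, 20, 50, 80]))  # [10, 20, 40, 60]\n",
    "print(xywh_to_xyxy([10, 20, 40, 60]))  # [10, 20, 50, 80]\n",
    "```\n"
   ]
  },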
  {
   "cell_type": "markdown",
   "id": "30c7ec4d",
   "metadata": {},
   "source": [
    "### model.mmdet_image.frozen_layers\n",
    "\n",
    "The layers to freeze. Every layer whose name contains one of the given substrings is frozen."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM, freeze nothing and update all parameters\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": []})\n",
    "# freeze the model's backbone\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": [\"backbone\"]})\n",
    "# freeze the model's backbone and neck\n",
    "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": [\"backbone\", \"neck\"]})\n",
    "```"
   ],
   "metadata": {
    "collapsed": false
   }
  },
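  {
   "cell_type": "markdown",
   "id": "sketch-frozen-layers",
   "metadata": {},
   "source": [
    "A list of parameter names can illustrate the substring rule. The names below are made up for illustration. In a PyTorch model, they would come from `model.named_parameters()`, and freezing a parameter means setting its `requires_grad` attribute to `False`. The helper `select_frozen` is hypothetical.\n",
    "\n",
    "```python\n",
    "def select_frozen(param_names, frozen_layers):\n",
    "    # A parameter is frozen if its name contains any of the given substrings.\n",
    "    return [name for name in param_names\n",
    "            if any(layer in name for layer in frozen_layers)]\n",
    "\n",
    "\n",
    "names = [\n",
    "    'backbone.stem.conv.weight',\n",
    "    'neck.reduce_layers.0.conv.weight',\n",
    "    'bbox_head.multi_level_cls_convs.0.weight',\n",
    "]\n",
    "print(select_frozen(names, []))  # []\n",
    "print(select_frozen(names, ['backbone']))  # ['backbone.stem.conv.weight']\n",
    "print(select_frozen(names, ['backbone', 'neck']))  # the backbone and neck parameters\n",
    "\n",
    "# With a real PyTorch model (not run here):\n",
    "# for name, param in model.named_parameters():\n",
    "#     if any(layer in name for layer in frozen_layers):\n",
    "#         param.requires_grad = False\n",
    "```\n"
   ]
  },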
  {
   "cell_type": "markdown",
   "id": "56ee3f92",
   "metadata": {},
   "source": [
    "### model.sam.checkpoint_name\n",
    "\n",
    "Specify a SAM backbone supported by the Hugging Face [SAM](https://huggingface.co/docs/transformers/main/model_doc/sam) implementation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccbd46cb",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-huge\"})\n",
    "# choose SAM-Large\n",
    "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-large\"})\n",
    "# choose SAM-Base\n",
    "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-base\"})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b106e2c8",
   "metadata": {},
   "source": [
    "### model.sam.train_transforms\n",
    "\n",
    "Augment images in training. Currently, only `random_horizontal_flip` is supported."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b22638",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.train_transforms\": [\"random_horizontal_flip\"]})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc1433cd",
   "metadata": {},
   "source": [
    "### model.sam.img_transforms\n",
    "\n",
    "Process input images for semantic segmentation. Currently, only `resize_to_square` is supported."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ffc0e05",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.img_transforms\": [\"resize_to_square\"]})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78130e5c",
   "metadata": {},
   "source": [
    "### model.sam.gt_transforms\n",
    "\n",
    "Process ground-truth masks for semantic segmentation. Currently, only `resize_gt_to_square` is supported."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5964c3b5",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.gt_transforms\": [\"resize_gt_to_square\"]})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ca01fcf",
   "metadata": {},
   "source": [
    "### model.sam.frozen_layers\n",
    "\n",
    "The SAM modules to freeze during training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4377a293",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.frozen_layers\": [\"mask_decoder.iou_prediction_head\", \"prompt_encoder\"]})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce7a5e4c",
   "metadata": {},
   "source": [
    "### model.sam.num_mask_tokens\n",
    "\n",
    "The number of mask proposals produced by SAM's mask decoder."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9be77770",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.num_mask_tokens\": 1})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3ef9e73",
   "metadata": {},
   "source": [
    "### model.sam.ignore_label\n",
    "\n",
    "Specifies a target value that is ignored and does not contribute to the training loss and metric calculation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d2e0373",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.sam.ignore_label\": 255})\n",
    "```"
   ]
  },
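  {
   "cell_type": "markdown",
   "id": "sketch-ignore-label",
   "metadata": {},
   "source": [
    "The sketch below shows what ignoring a label means in practice. It computes pixel accuracy over a flattened mask and skips every pixel whose target equals the ignore label. It is a plain-Python illustration with a hypothetical function name, and the same masking idea applies to the per-pixel training loss.\n",
    "\n",
    "```python\n",
    "def masked_pixel_accuracy(pred, target, ignore_label=255):\n",
    "    # Keep only positions whose ground truth is not the ignore label.\n",
    "    pairs = [(p, t) for p, t in zip(pred, target) if t != ignore_label]\n",
    "    if not pairs:\n",
    "        return float('nan')  # nothing left to evaluate\n",
    "    return sum(p == t for p, t in pairs) / len(pairs)\n",
    "\n",
    "\n",
    "pred = [1, 0, 1, 1]\n",
    "target = [1, 0, 0, 255]  # the last pixel is ignored\n",
    "print(masked_pixel_accuracy(pred, target))  # 0.6666666666666666\n",
    "```\n"
   ]
  },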
  {
   "cell_type": "markdown",
   "id": "690a7766",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "### data.image.missing_value_strategy\n",
    "\n",
    "How to handle missing images, i.e., images that fail to open.\n",
    "\n",
    "- `\"skip\"`: skip any sample with missing images.\n",
    "- `\"zero\"`: replace a missing image with an all-zero image."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed5ad640",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.image.missing_value_strategy\": \"zero\"})\n",
    "# skip the image\n",
    "predictor.fit(hyperparameters={\"data.image.missing_value_strategy\": \"skip\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b5a689b",
   "metadata": {},
   "source": [
    "### data.text.normalize_text\n",
    "\n",
    "Whether to normalize text that has encoding problems. If True, TextProcessor runs the text through a series of encoding and decoding steps to normalize it. See this [Kaggle competition example](https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize) for how to apply text normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab6c46ad",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.text.normalize_text\": False})\n",
    "# turn on text normalization\n",
    "predictor.fit(hyperparameters={\"data.text.normalize_text\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "177eb155",
   "metadata": {},
   "source": [
    "### data.categorical.convert_to_text\n",
    "\n",
    "Whether to treat categorical data as text. If True, no categorical models (e.g., `\"categorical_mlp\"` and `\"categorical_transformer\"`) are used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "267ad0a9",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": False})\n",
    "# turn on the conversion\n",
    "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bf8d9e6",
   "metadata": {},
   "source": [
    "### data.numerical.convert_to_text\n",
    "\n",
    "Whether to convert numerical data to text. If True, no numerical models (e.g., `\"numerical_mlp\"` and `\"numerical_transformer\"`) are used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f158d9a0",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": False})\n",
    "# turn on the conversion\n",
    "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "daaa41c9",
   "metadata": {},
   "source": [
    "### data.numerical.scaler_with_mean\n",
    "\n",
    "If True, center the numerical data (not including the numerical labels) before scaling."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "984abb92",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": True})\n",
    "# turn off centering\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "589677a4",
   "metadata": {},
   "source": [
    "### data.numerical.scaler_with_std\n",
    "\n",
    "If True, scale the numerical data (not including the numerical labels) to unit variance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cfca7db",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": True})\n",
    "# turn off scaling\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": False})\n",
    "```\n"
   ]
  },
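  {
   "cell_type": "markdown",
   "id": "sketch-numerical-scaler",
   "metadata": {},
   "source": [
    "These two flags mirror the `with_mean` and `with_std` options of scikit-learn's `StandardScaler`. Assuming AutoMM follows the same semantics, each numerical column is transformed as in this plain-Python sketch. The function name is hypothetical. Like scikit-learn, the sketch uses the population standard deviation and leaves constant columns unscaled.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "\n",
    "def scale_column(values, with_mean=True, with_std=True):\n",
    "    mean = statistics.fmean(values) if with_mean else 0.0\n",
    "    std = statistics.pstdev(values) if with_std else 1.0\n",
    "    if std == 0:\n",
    "        std = 1.0  # avoid division by zero for constant columns\n",
    "    return [(v - mean) / std for v in values]\n",
    "\n",
    "\n",
    "col = [1.0, 2.0, 3.0]\n",
    "print(scale_column(col))  # approximately [-1.2247, 0.0, 1.2247]\n",
    "print(scale_column(col, with_mean=False))  # scaled to unit variance but not centered\n",
    "print(scale_column(col, with_std=False))  # [-1.0, 0.0, 1.0]\n",
    "```\n"
   ]
  },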
  {
   "cell_type": "markdown",
   "id": "9241360b",
   "metadata": {},
   "source": [
    "### data.label.numerical_preprocessing\n",
    "\n",
    "How to process the numerical labels in regression tasks.\n",
    "\n",
    "- `\"standardscaler\"`: standardizes numerical labels by removing the mean and scaling to unit variance.\n",
    "- `\"minmaxscaler\"`: scales numerical labels to the range [0, 1]."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bea7d018",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.label.numerical_preprocessing\": \"standardscaler\"})\n",
    "# scale numerical labels to [0, 1]\n",
    "predictor.fit(hyperparameters={\"data.label.numerical_preprocessing\": \"minmaxscaler\"})\n",
    "```\n"
   ]
  },
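  {
   "cell_type": "markdown",
   "id": "sketch-minmax-labels",
   "metadata": {},
   "source": [
    "Because the labels are transformed before training, predictions have to be mapped back to the original scale. Below is a minimal sketch of the `minmaxscaler` option and its inverse. The helper names are hypothetical, and the sketch assumes scikit-learn's default [0, 1] output range.\n",
    "\n",
    "```python\n",
    "def fit_minmax(labels):\n",
    "    lo, hi = min(labels), max(labels)\n",
    "    span = (hi - lo) or 1.0  # guard against constant labels\n",
    "    return lo, span\n",
    "\n",
    "\n",
    "def minmax_transform(labels, lo, span):\n",
    "    return [(y - lo) / span for y in labels]\n",
    "\n",
    "\n",
    "def minmax_inverse(scaled, lo, span):\n",
    "    # Map model outputs back to the original label scale.\n",
    "    return [s * span + lo for s in scaled]\n",
    "\n",
    "\n",
    "y = [10.0, 15.0, 30.0]\n",
    "lo, span = fit_minmax(y)\n",
    "scaled = minmax_transform(y, lo, span)\n",
    "print(scaled)  # [0.0, 0.25, 1.0]\n",
    "print(minmax_inverse(scaled, lo, span))  # [10.0, 15.0, 30.0]\n",
    "```\n"
   ]
  },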
  {
   "cell_type": "markdown",
   "id": "103f689a",
   "metadata": {},
   "source": [
    "### data.pos_label\n",
    "\n",
    "The positive label in a binary classification task. Users need to specify this label to properly use some metrics, e.g., roc_auc, average_precision, and f1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a14b1af",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.pos_label\": None})\n",
    "# assume the labels are [\"changed\", \"not changed\"] and \"changed\" is the positive label\n",
    "predictor.fit(hyperparameters={\"data.pos_label\": \"changed\"})\n",
    "```\n"
   ]
  },
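  {
   "cell_type": "markdown",
   "id": "sketch-pos-label-f1",
   "metadata": {},
   "source": [
    "Precision, recall, and F1 are all computed with respect to one class, so choosing a different positive label changes the score. The plain-Python sketch below uses a hypothetical helper and made-up predictions with the labels from the example above. It computes F1 for a chosen positive label.\n",
    "\n",
    "```python\n",
    "def binary_f1(y_true, y_pred, pos_label):\n",
    "    pairs = list(zip(y_true, y_pred))\n",
    "    tp = sum(t == pos_label and p == pos_label for t, p in pairs)\n",
    "    fp = sum(t != pos_label and p == pos_label for t, p in pairs)\n",
    "    fn = sum(t == pos_label and p != pos_label for t, p in pairs)\n",
    "    if tp == 0:\n",
    "        return 0.0\n",
    "    precision = tp / (tp + fp)\n",
    "    recall = tp / (tp + fn)\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "\n",
    "y_true = ['changed', 'changed', 'not changed', 'not changed']\n",
    "y_pred = ['changed', 'not changed', 'not changed', 'not changed']\n",
    "print(binary_f1(y_true, y_pred, 'changed'))  # about 0.667\n",
    "print(binary_f1(y_true, y_pred, 'not changed'))  # about 0.8\n",
    "```\n"
   ]
  },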
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### data.column_features_pooling_mode\n",
    "\n",
    "How to aggregate column features into one feature vector for a dataframe with multiple feature columns. Currently, this only works for `few_shot_classification`.\n",
    "\n",
    "- `\"concat\"`: Concatenate the features of different columns into one long feature vector.\n",
    "- `\"mean\"`: Average the column features so that the feature dimension does not grow with the number of columns."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.column_features_pooling_mode\": \"concat\"})\n",
    "# use the mean pooling\n",
    "predictor.fit(hyperparameters={\"data.column_features_pooling_mode\": \"mean\"})\n",
    "```\n"
   ]
  },
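  {
   "cell_type": "markdown",
   "id": "sketch-column-pooling",
   "metadata": {},
   "source": [
    "The two pooling modes differ only in how the per-column vectors are combined. Below is a plain-Python sketch with a hypothetical function name. It assumes that every column produces a feature vector of the same dimension.\n",
    "\n",
    "```python\n",
    "def pool_column_features(column_features, mode='concat'):\n",
    "    # column_features: one feature vector (list of floats) per column\n",
    "    if mode == 'concat':\n",
    "        return [x for feats in column_features for x in feats]\n",
    "    if mode == 'mean':\n",
    "        n = len(column_features)\n",
    "        return [sum(vals) / n for vals in zip(*column_features)]\n",
    "    raise ValueError(f'unknown pooling mode: {mode}')\n",
    "\n",
    "\n",
    "feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 columns, dimension 2\n",
    "print(pool_column_features(feats, 'concat'))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]\n",
    "print(pool_column_features(feats, 'mean'))  # [3.0, 4.0]\n",
    "```\n"
   ]
  },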
  {
   "cell_type": "markdown",
   "id": "e3e3242f",
   "metadata": {},
   "source": [
    "### data.mixup.turn_on\n",
    "\n",
    "If True, use Mixup in training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0161ba46",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_on\": False})\n",
    "# turn on Mixup\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_on\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd97b924",
   "metadata": {},
   "source": [
    "### data.mixup.mixup_alpha\n",
    "\n",
    "Mixup alpha value. Mixup is active if `data.mixup.mixup_alpha` > 0."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd9bc14b",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 0.8})\n",
    "# set alpha to 1.0, which samples the mixing ratio uniformly from [0, 1]\n",
    "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 1.0})\n",
    "```\n"
   ]
  },
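  {
   "cell_type": "markdown",
   "id": "sketch-mixup",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of Mixup in plain Python. The mixing ratio is drawn from a Beta(alpha, alpha) distribution, and the same ratio mixes both the inputs and the one-hot labels. The helper name is hypothetical, and the inputs are flat lists rather than image tensors. With alpha = 1.0 the ratio is uniform on [0, 1], and larger alpha values concentrate it around 0.5.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def mixup_pair(x1, y1, x2, y2, alpha=0.8, rng=random):\n",
    "    # Sample the mixing ratio from Beta(alpha, alpha).\n",
    "    lam = rng.betavariate(alpha, alpha)\n",
    "    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]\n",
    "    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]\n",
    "    return x, y, lam\n",
    "\n",
    "\n",
    "x, y, lam = mixup_pair([1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0])\n",
    "print(lam, x, y)  # x == [lam, lam], y == [lam, 1 - lam]\n",
    "```\n"
   ]
  },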
  {
   "cell_type": "markdown",
   "id": "b48c9408",
   "metadata": {},
   "source": [
    "### data.mixup.cutmix_alpha\n",
    "\n",
    "Cutmix alpha value. Cutmix is active if `data.mixup.cutmix_alpha` > 0."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fc9b53a",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 1.0})\n",
    "# set alpha to 0.8\n",
    "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 0.8})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3a58751",
   "metadata": {},
   "source": [
    "### data.mixup.prob\n",
    "\n",
    "The probability of conducting Mixup or Cutmix if enabled."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "738cdcc7",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.prob\": 1.0})\n",
    "# set probability to 0.5\n",
    "predictor.fit(hyperparameters={\"data.mixup.prob\": 0.5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5991c094",
   "metadata": {},
   "source": [
    "### data.mixup.switch_prob\n",
    "\n",
    "The probability of switching to Cutmix instead of Mixup when both are active."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24393ef",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.5})\n",
    "# set probability to 0.7\n",
    "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.7})\n",
    "```\n"
   ]
  },
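  {
   "cell_type": "markdown",
   "id": "sketch-mixup-switch",
   "metadata": {},
   "source": [
    "The interaction between `data.mixup.prob` and `data.mixup.switch_prob` can be sketched as a per-batch decision. The sketch is modeled on the selection logic of timm's `Mixup` class, which these settings appear to mirror (an assumption). The function `choose_mix_op` is hypothetical and only reports which augmentation would be applied.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def choose_mix_op(mixup_alpha, cutmix_alpha, prob, switch_prob, rng=random):\n",
    "    # With probability 1 - prob, leave the batch unchanged.\n",
    "    if rng.random() >= prob:\n",
    "        return 'none'\n",
    "    if mixup_alpha > 0 and cutmix_alpha > 0:\n",
    "        # Both are active: pick Cutmix with probability switch_prob.\n",
    "        return 'cutmix' if rng.random() < switch_prob else 'mixup'\n",
    "    if mixup_alpha > 0:\n",
    "        return 'mixup'\n",
    "    if cutmix_alpha > 0:\n",
    "        return 'cutmix'\n",
    "    return 'none'\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "counts = {'none': 0, 'mixup': 0, 'cutmix': 0}\n",
    "for _ in range(10000):\n",
    "    counts[choose_mix_op(0.8, 1.0, prob=1.0, switch_prob=0.5, rng=rng)] += 1\n",
    "print(counts)  # roughly half mixup and half cutmix, never none\n",
    "```\n"
   ]
  },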
  {
   "cell_type": "markdown",
   "id": "ab459677",
   "metadata": {},
   "source": [
    "### data.mixup.mode\n",
    "\n",
    "The granularity at which Mixup or Cutmix parameters are applied: per `\"batch\"`, per `\"pair\"` of elements, or per `\"elem\"` (element).\n",
    "See [here](https://github.com/rwightman/pytorch-image-models/blob/d30685c283137b4b91ea43c4e595c964cd2cb6f0/timm/data/mixup.py#L211-L216) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ada57733",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.mode\": \"batch\"})\n",
    "# use \"pair\"\n",
    "predictor.fit(hyperparameters={\"data.mixup.mode\": \"pair\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737d454d",
   "metadata": {},
   "source": [
    "### data.mixup.label_smoothing\n",
    "\n",
    "The label smoothing factor applied to the mixed label tensors."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40f1d216",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.1})\n",
    "# set it to 0.2\n",
    "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.2})\n",
    "```\n"
   ]
  },
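  {
   "cell_type": "markdown",
   "id": "sketch-label-smoothing",
   "metadata": {},
   "source": [
    "The mixed, smoothed target can be sketched by following timm's `mixup_target` helper, assuming AutoMM behaves the same way. Each one-hot label is first smoothed, with an off value of `smoothing / num_classes` and an on value of `1 - smoothing + off`. The two smoothed labels are then mixed with the Mixup ratio. The function names below are hypothetical.\n",
    "\n",
    "```python\n",
    "def smooth_one_hot(label, num_classes, smoothing):\n",
    "    off = smoothing / num_classes\n",
    "    on = 1.0 - smoothing + off\n",
    "    return [on if c == label else off for c in range(num_classes)]\n",
    "\n",
    "\n",
    "def mixed_smoothed_target(label_a, label_b, lam, num_classes, smoothing=0.1):\n",
    "    ya = smooth_one_hot(label_a, num_classes, smoothing)\n",
    "    yb = smooth_one_hot(label_b, num_classes, smoothing)\n",
    "    return [lam * a + (1 - lam) * b for a, b in zip(ya, yb)]\n",
    "\n",
    "\n",
    "print(smooth_one_hot(0, 4, 0.1))  # on value about 0.925, off values 0.025\n",
    "print(mixed_smoothed_target(0, 1, lam=0.7, num_classes=4))\n",
    "```\n"
   ]
  },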
  {
   "cell_type": "markdown",
   "id": "73c8a29d",
   "metadata": {},
   "source": [
    "### data.mixup.turn_off_epoch\n",
    "\n",
    "Stop Mixup or Cutmix after reaching this number of epochs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0c3715f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 5})\n",
    "# turn off mixup after 7 epochs\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 7})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb33f07",
   "metadata": {},
   "source": [
    "## Distiller\n",
    "\n",
    "### distiller.soft_label_loss_type\n",
    "\n",
    "The loss computed when the teacher's output logits are used to supervise the student's output logits."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3ca2c3d",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"cross_entropy\"})\n",
    "# default used by AutoMM for regression\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"mse\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c91287a0",
   "metadata": {},
   "source": [
    "### distiller.temperature\n",
    "\n",
    "The temperature used to scale both the teacher and student logits before computing the soft label loss (teacher_logits / temperature, student_logits / temperature)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f67e3e1",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.temperature\": 5})\n",
    "# set temperature to 1\n",
    "predictor.fit(hyperparameters={\"distiller.temperature\": 1})\n",
    "```\n"
   ]
  },
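  {
   "cell_type": "markdown",
   "id": "sketch-distillation-loss",
   "metadata": {},
   "source": [
    "The sketch below makes the temperature concrete. It computes a soft-label cross-entropy between temperature-scaled teacher and student logits. It then combines this loss with a hard-label loss, weighting them by `distiller.hard_label_weight` and `distiller.soft_label_weight` (documented in the following sections), with defaults of 0.2 and 50 taken from those examples. The sketch is plain Python with hypothetical function names. AutoMM's distiller may differ in details, such as an additional scaling of the soft loss.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softmax(logits, temperature=1.0):\n",
    "    scaled = [z / temperature for z in logits]\n",
    "    m = max(scaled)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(z - m) for z in scaled]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def soft_cross_entropy(student_logits, teacher_logits, temperature):\n",
    "    # Teacher probabilities supervise the student's log-probabilities.\n",
    "    p_teacher = softmax(teacher_logits, temperature)\n",
    "    p_student = softmax(student_logits, temperature)\n",
    "    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))\n",
    "\n",
    "\n",
    "def distill_loss(hard_loss, student_logits, teacher_logits,\n",
    "                 temperature=5, hard_label_weight=0.2, soft_label_weight=50):\n",
    "    soft_loss = soft_cross_entropy(student_logits, teacher_logits, temperature)\n",
    "    return hard_label_weight * hard_loss + soft_label_weight * soft_loss\n",
    "\n",
    "\n",
    "teacher = [4.0, 1.0, 0.0]\n",
    "student = [2.0, 1.5, 0.5]\n",
    "print(softmax(teacher, 1))  # sharp distribution\n",
    "print(softmax(teacher, 5))  # smoother distribution at a higher temperature\n",
    "print(distill_loss(0.7, student, teacher))\n",
    "```\n"
   ]
  },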
  {
   "cell_type": "markdown",
   "id": "2f95727c",
   "metadata": {},
   "source": [
    "### distiller.hard_label_weight\n",
    "\n",
    "Scale the student's hard label (ground-truth) loss with this weight (hard_label_loss \\* hard_label_weight)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f5d5eca",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 0.2})\n",
    "# do not scale the hard label loss\n",
    "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ebc0b75",
   "metadata": {},
   "source": [
    "### distiller.soft_label_weight\n",
    "\n",
    "Scale the student's soft label (teacher's output) loss with this weight (soft_label_loss \\* soft_label_weight)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b3b90c2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 50})\n",
    "# do not scale the soft label loss\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 1})\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
